
Add explicit RL micro-batch token cap and fix RL token accounting#2183

Open
taivu1998 wants to merge 2 commits into PrimeIntellect-ai:main from taivu1998:fix/1514-rl-micro-batch-max-tokens

Conversation


@taivu1998 taivu1998 commented Apr 2, 2026

Summary

Closes #1514.

This adds an explicit RL trainer control for local micro-batch packing without reintroducing the old SFT-style micro_batch_size semantics.

The new API is trainer.micro_batch_max_tokens, which caps how many text tokens the RL trainer packs into each local micro batch while preserving the existing RL step semantics:

  • one trainer step still performs one optimizer step
  • checkpoint cadence is unchanged
  • weight-broadcast cadence is unchanged
  • fake RL data remains controlled by trainer.data.fake.batch_size
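The step semantics preserved above can be sketched in a few lines (this is not code from the PR; `forward_backward`, the dict-based micro batch, and the fake optimizer are hypothetical stand-ins):

```python
def run_trainer_step(micro_batches, forward_backward, optimizer):
    """One RL trainer step: gradients accumulate across however many
    packed local micro batches the token cap produced, then exactly one
    optimizer step is taken. Checkpoint and weight-broadcast cadence
    keyed to trainer steps is therefore unaffected by the cap."""
    for mb in micro_batches:
        forward_backward(mb)  # backward pass accumulates into .grad
    optimizer.step()          # a single optimizer step per trainer step
    optimizer.zero_grad()
    return len(micro_batches)
```

A lower token cap only changes how many micro batches the loop sees, never how many optimizer steps run.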

Motivation

Issue #1514 asks for RL-side gradient accumulation control similar to other training paths. In practice, RL already accumulates gradients implicitly across the packed local micro batches that make up one trainer step, but that behavior was only controlled indirectly by packing against model.seq_len.

This change makes that control explicit and safer by:

  • exposing a dedicated RL knob instead of reviving micro_batch_size
  • keeping sample truncation (model.seq_len) separate from local packing capacity (micro_batch_max_tokens)
  • preserving the existing optimizer-step semantics that the async RL pipeline depends on

What Changed

Config

  • Added trainer.micro_batch_max_tokens: int | None
  • Default remains None, which preserves current behavior by falling back to model.seq_len
  • Validation rejects values above model.seq_len
  • Validation also rejects this knob for fake RL data, since fake RL already has its own local micro-batch control via data.fake.batch_size

RL batching and packing

  • Decoupled per-sample truncation from per-micro-batch packing in the RL batch preparation path
  • Threaded micro_batch_max_tokens through the real RL data loader and packers
  • Kept run-isolation behavior intact for multi-run / LoRA packing
  • Added sample_count metadata on packed micro batches so sample accounting remains correct after packing and padding
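The truncation/packing decoupling can be illustrated with a toy greedy packer (a sketch only, not the repo's packer; the dict-based micro batch and its `tokens`/`sample_count` keys are stand-ins):

```python
def pack(sample_lens, seq_len, max_tokens):
    """Truncate each sample to seq_len, then greedily pack samples into
    micro batches of at most max_tokens text tokens, recording
    sample_count so per-sample accounting survives packing and padding."""
    lens = [min(n, seq_len) for n in sample_lens]   # truncation: seq_len
    batches, cur = [], []
    for n in lens:
        if cur and sum(cur) + n > max_tokens:       # capacity: max_tokens
            batches.append({"tokens": sum(cur), "sample_count": len(cur)})
            cur = []
        cur.append(n)
    if cur:
        batches.append({"tokens": sum(cur), "sample_count": len(cur)})
    return batches
```

Note how the two limits act independently: `seq_len` clips each sample, while `max_tokens` only decides when a new micro batch starts, so lowering the cap produces more, smaller micro batches.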

Trainer accounting and logging

  • Replaced the old RL token-accounting assumption that used the first micro-batch shape as the whole-step token count
  • RL throughput/progress now use the actual packed local token count across all micro batches in the step
  • Added logging/monitor metrics for:
    • local tokens
    • local loss tokens
    • local samples
    • local micro-batch count
    • max packed local micro-batch tokens
    • configured micro_batch_max_tokens
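The corrected accounting amounts to summing over every packed micro batch rather than extrapolating from the first micro batch's shape; a sketch (field and metric names here are hypothetical):

```python
def step_batch_stats(micro_batches, configured_cap):
    """Per-step local stats, aggregated over all packed micro batches in
    the step instead of assuming the first micro batch's shape."""
    return {
        "local_tokens": sum(mb["tokens"] for mb in micro_batches),
        "local_loss_tokens": sum(
            mb.get("loss_tokens", mb["tokens"]) for mb in micro_batches
        ),
        "local_samples": sum(mb["sample_count"] for mb in micro_batches),
        "local_micro_batches": len(micro_batches),
        "max_micro_batch_tokens": max(mb["tokens"] for mb in micro_batches),
        "micro_batch_max_tokens": configured_cap,
    }
```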

Docs and tests

  • Updated docs that still referenced nonexistent RL micro-batch-size flags
  • Added tests covering:
    • config validation for micro_batch_max_tokens
    • decoupled sample truncation vs packing cap
    • packer behavior when the cap forces more local micro batches

Design Notes

A key goal here is to stay aligned with the current RL architecture instead of importing SFT semantics wholesale.

This PR intentionally does not:

  • add RL micro_batch_size
  • change the meaning of one RL trainer step
  • modify fake-data batching behavior
  • change orchestrator scheduling or checkpoint step behavior

That keeps the implementation small and makes the new knob do exactly one thing: lower per-forward memory pressure by reducing the token budget of each local RL micro batch.

Verification

I added focused unit coverage for the new config and packing behavior.

Test command attempted locally:

uv run pytest tests/unit/test_configs.py tests/unit/orchestrator/test_batch.py tests/unit/train/rl/test_packer.py -q

The repo lockfile only supports Linux environments, so uv could not run the tests on this macOS machine. I still verified the patch with:

  • python3 -m py_compile on all changed Python files
  • git diff --check
  • manual diff review of the RL config, batching, packing, trainer-accounting, and docs changes

Note

Medium Risk
Touches RL trainer batching, packing, and distributed progress accounting; mistakes could skew metrics or change effective training workload despite validations and tests.

Overview
Adds new RL config trainer.micro_batch_max_tokens to cap how many tokens are packed into each local RL micro-batch (defaulting to model.seq_len), with validation to forbid values above model.seq_len and to reject use with fake RL data.

Threads this cap through the real RL packer/data loader and updates batch preparation to separate per-sample truncation (seq_len) from packing capacity, introducing sample_count on MicroBatch (and dummy padding batches) so sample accounting remains correct after packing.

Fixes RL trainer progress/throughput metrics to use actual packed tokens and samples aggregated across DP ranks (via new rl/stats.py helpers), and logs additional per-step batch stats; updates docs and adds targeted unit tests for config validation, packing behavior, and stats aggregation.
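The cross-rank aggregation is, in effect, an all-reduce(SUM) of each rank's packed token and sample counts; a torch-free sketch of what such a helper computes (function name and dict layout are assumptions, not the actual rl/stats.py API):

```python
def reduce_dp_stats(per_rank):
    """Global step stats as an all-reduce(SUM) over DP ranks would
    produce them: throughput then reflects the real packed token counts
    from every rank, not an estimate from one rank's first micro batch."""
    keys = per_rank[0].keys()
    return {k: sum(rank_stats[k] for rank_stats in per_rank) for k in keys}
```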

Reviewed by Cursor Bugbot for commit 2c370e6.


cursor bot left a comment:


Cursor Bugbot has reviewed your changes and found 1 potential issue.



Successfully merging this pull request may close these issues.

Gradient Accumulation?
